{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simple Binary Classification with defaults\n",
    "\n",
    "In this notebook we will use the Adult Census dataset. Download the data from [here](https://www.kaggle.com/wenruliu/adult-income-dataset/downloads/adult.csv/2)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor\n",
    "from pytorch_widedeep.models import Wide, DeepDense, WideDeep\n",
    "from pytorch_widedeep.metrics import Accuracy, Precision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>educational-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>gender</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>25</td>\n",
       "      <td>Private</td>\n",
       "      <td>226802</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>89814</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Farming-fishing</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>50</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>28</td>\n",
       "      <td>Local-gov</td>\n",
       "      <td>336951</td>\n",
       "      <td>Assoc-acdm</td>\n",
       "      <td>12</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Protective-serv</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&gt;50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>44</td>\n",
       "      <td>Private</td>\n",
       "      <td>160323</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>7688</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&gt;50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>18</td>\n",
       "      <td>?</td>\n",
       "      <td>103497</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>?</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>White</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>30</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age  workclass  fnlwgt     education  educational-num      marital-status  \\\n",
       "0   25    Private  226802          11th                7       Never-married   \n",
       "1   38    Private   89814       HS-grad                9  Married-civ-spouse   \n",
       "2   28  Local-gov  336951    Assoc-acdm               12  Married-civ-spouse   \n",
       "3   44    Private  160323  Some-college               10  Married-civ-spouse   \n",
       "4   18          ?  103497  Some-college               10       Never-married   \n",
       "\n",
       "          occupation relationship   race  gender  capital-gain  capital-loss  \\\n",
       "0  Machine-op-inspct    Own-child  Black    Male             0             0   \n",
       "1    Farming-fishing      Husband  White    Male             0             0   \n",
       "2    Protective-serv      Husband  White    Male             0             0   \n",
       "3  Machine-op-inspct      Husband  Black    Male          7688             0   \n",
       "4                  ?    Own-child  White  Female             0             0   \n",
       "\n",
       "   hours-per-week native-country income  \n",
       "0              40  United-States  <=50K  \n",
       "1              50  United-States  <=50K  \n",
       "2              40  United-States   >50K  \n",
       "3              40  United-States   >50K  \n",
       "4              30  United-States  <=50K  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('data/adult/adult.csv.zip')\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>educational_num</th>\n",
       "      <th>marital_status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>gender</th>\n",
       "      <th>capital_gain</th>\n",
       "      <th>capital_loss</th>\n",
       "      <th>hours_per_week</th>\n",
       "      <th>native_country</th>\n",
       "      <th>income_label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>25</td>\n",
       "      <td>Private</td>\n",
       "      <td>226802</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>89814</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Farming-fishing</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>50</td>\n",
       "      <td>United-States</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>28</td>\n",
       "      <td>Local-gov</td>\n",
       "      <td>336951</td>\n",
       "      <td>Assoc-acdm</td>\n",
       "      <td>12</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Protective-serv</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>44</td>\n",
       "      <td>Private</td>\n",
       "      <td>160323</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>7688</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>18</td>\n",
       "      <td>?</td>\n",
       "      <td>103497</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>?</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>White</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>30</td>\n",
       "      <td>United-States</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age  workclass  fnlwgt     education  educational_num      marital_status  \\\n",
       "0   25    Private  226802          11th                7       Never-married   \n",
       "1   38    Private   89814       HS-grad                9  Married-civ-spouse   \n",
       "2   28  Local-gov  336951    Assoc-acdm               12  Married-civ-spouse   \n",
       "3   44    Private  160323  Some-college               10  Married-civ-spouse   \n",
       "4   18          ?  103497  Some-college               10       Never-married   \n",
       "\n",
       "          occupation relationship   race  gender  capital_gain  capital_loss  \\\n",
       "0  Machine-op-inspct    Own-child  Black    Male             0             0   \n",
       "1    Farming-fishing      Husband  White    Male             0             0   \n",
       "2    Protective-serv      Husband  White    Male             0             0   \n",
       "3  Machine-op-inspct      Husband  Black    Male          7688             0   \n",
       "4                  ?    Own-child  White  Female             0             0   \n",
       "\n",
       "   hours_per_week native_country  income_label  \n",
       "0              40  United-States             0  \n",
       "1              50  United-States             0  \n",
       "2              40  United-States             1  \n",
       "3              40  United-States             1  \n",
       "4              30  United-States             0  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# For convenience, we'll replace '-' with '_'\n",
    "df.columns = [c.replace(\"-\", \"_\") for c in df.columns]\n",
    "# binary target\n",
    "df['income_label'] = (df[\"income\"].apply(lambda x: \">50K\" in x)).astype(int)\n",
    "df.drop('income', axis=1, inplace=True)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preparing the data\n",
    "\n",
    "Have a look to notebooks one and two if you want to get a good understanding of the next few lines of code (although there is no need to use the package)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "wide_cols = ['education', 'relationship','workclass','occupation','native_country','gender']\n",
    "crossed_cols = [('education', 'occupation'), ('native_country', 'occupation')]\n",
    "cat_embed_cols = [('education',16), ('relationship',8), ('workclass',16), ('occupation',16),('native_country',16)]\n",
    "continuous_cols = [\"age\",\"hours_per_week\"]\n",
    "target_col = 'income_label'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TARGET\n",
    "target = df[target_col].values\n",
    "\n",
    "# WIDE\n",
    "preprocess_wide = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)\n",
    "X_wide = preprocess_wide.fit_transform(df)\n",
    "\n",
    "# DEEP\n",
    "preprocess_deep = DensePreprocessor(embed_cols=cat_embed_cols, continuous_cols=continuous_cols)\n",
    "X_deep = preprocess_deep.fit_transform(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[  1  17  23 ...  89  91 316]\n",
      " [  2  18  23 ...  89  92 317]\n",
      " [  3  18  24 ...  89  93 318]\n",
      " ...\n",
      " [  2  20  23 ...  90 103 323]\n",
      " [  2  17  23 ...  89 103 323]\n",
      " [  2  21  29 ...  90 115 324]]\n",
      "(48842, 8)\n"
     ]
    }
   ],
   "source": [
    "print(X_wide)\n",
    "print(X_wide.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.          0.          0.         ...  0.         -0.99512893\n",
      "  -0.03408696]\n",
      " [ 1.          1.          0.         ...  0.         -0.04694151\n",
      "   0.77292975]\n",
      " [ 2.          1.          1.         ...  0.         -0.77631645\n",
      "  -0.03408696]\n",
      " ...\n",
      " [ 1.          3.          0.         ...  0.          1.41180837\n",
      "  -0.03408696]\n",
      " [ 1.          0.          0.         ...  0.         -1.21394141\n",
      "  -1.64812038]\n",
      " [ 1.          4.          6.         ...  0.          0.97418341\n",
      "  -0.03408696]]\n",
      "(48842, 7)\n"
     ]
    }
   ],
   "source": [
    "print(X_deep)\n",
    "print(X_deep.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)\n",
    "deepdense = DeepDense(hidden_layers=[64,32], \n",
    "                      deep_column_idx=preprocess_deep.deep_column_idx,\n",
    "                      embed_input=preprocess_deep.embeddings_input,\n",
    "                      continuous_cols=continuous_cols)\n",
    "model = WideDeep(wide=wide, deepdense=deepdense)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "WideDeep(\n",
       "  (wide): Wide(\n",
       "    (wide_linear): Embedding(797, 1, padding_idx=0)\n",
       "  )\n",
       "  (deepdense): Sequential(\n",
       "    (0): DeepDense(\n",
       "      (embed_layers): ModuleDict(\n",
       "        (emb_layer_education): Embedding(17, 16)\n",
       "        (emb_layer_native_country): Embedding(43, 16)\n",
       "        (emb_layer_occupation): Embedding(16, 16)\n",
       "        (emb_layer_relationship): Embedding(7, 8)\n",
       "        (emb_layer_workclass): Embedding(10, 16)\n",
       "      )\n",
       "      (embed_dropout): Dropout(p=0.0, inplace=False)\n",
       "      (dense): Sequential(\n",
       "        (dense_layer_0): Sequential(\n",
       "          (0): Linear(in_features=74, out_features=64, bias=True)\n",
       "          (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (dense_layer_1): Sequential(\n",
       "          (0): Linear(in_features=64, out_features=32, bias=True)\n",
       "          (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (1): Linear(in_features=32, out_features=1, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, the model is not particularly complex. In mathematical terms (Eq 3 in the [original paper](https://arxiv.org/pdf/1606.07792.pdf)): \n",
    "\n",
    "$$\n",
    "pred = \\sigma(W^{T}_{wide}[x, \\phi(x)] + W^{T}_{deep}a_{deep}^{(l_f)} +  b) \n",
    "$$ \n",
    "\n",
    "\n",
    "The architecture above will output the 1st and the second term in the parenthesis. `WideDeep` will then add them and apply an activation function (`sigmoid` in this case). For more details, please refer to the paper."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compiling and Running/Fitting\n",
    "Once the model is built, we just need to compile it and run it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(method='binary', metrics=[Accuracy, Precision])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/611 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████| 611/611 [00:04<00:00, 128.34it/s, loss=0.655, metrics={'acc': 0.6487, 'prec': 0.2352}]\n",
      "valid: 100%|██████████| 153/153 [00:00<00:00, 173.04it/s, loss=0.517, metrics={'acc': 0.6659, 'prec': 0.2524}]\n",
      "epoch 2: 100%|██████████| 611/611 [00:04<00:00, 133.99it/s, loss=0.466, metrics={'acc': 0.7725, 'prec': 0.5456}]\n",
      "valid: 100%|██████████| 153/153 [00:00<00:00, 171.03it/s, loss=0.433, metrics={'acc': 0.7765, 'prec': 0.5598}]\n",
      "epoch 3: 100%|██████████| 611/611 [00:04<00:00, 132.06it/s, loss=0.413, metrics={'acc': 0.803, 'prec': 0.6451}] \n",
      "valid: 100%|██████████| 153/153 [00:00<00:00, 172.22it/s, loss=0.4, metrics={'acc': 0.8045, 'prec': 0.648}]   \n",
      "epoch 4: 100%|██████████| 611/611 [00:04<00:00, 131.57it/s, loss=0.39, metrics={'acc': 0.8181, 'prec': 0.6836}] \n",
      "valid: 100%|██████████| 153/153 [00:00<00:00, 169.62it/s, loss=0.384, metrics={'acc': 0.8195, 'prec': 0.6841}]\n",
      "epoch 5: 100%|██████████| 611/611 [00:04<00:00, 130.85it/s, loss=0.378, metrics={'acc': 0.8247, 'prec': 0.6941}]\n",
      "valid: 100%|██████████| 153/153 [00:00<00:00, 171.85it/s, loss=0.376, metrics={'acc': 0.8254, 'prec': 0.6946}]\n"
     ]
    }
   ],
   "source": [
    "model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=5, batch_size=64, val_split=0.2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, you can run a wide and deep model in just a few lines of code"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
